#!/usr/bin/env python3
"""
subset_divergence_state.py

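Compare CPU-vs-GPU divergence between items 1-40 and items 41-50. Divergence is
the absolute difference |CPU - GPU| between each item's mean word count over the
runs of the chosen group (Warm/Cold/Alt). For every condition, the two subsets
are compared with Welch's t-test and the Mann–Whitney U test.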

Usage:
  python subset_divergence_state.py --base_dir "/path/to/logs" --group Warm
"""

import argparse
import re
from pathlib import Path
import numpy as np
from scipy import stats

# Matches a numbered-item header at the start of a line: optional leading
# whitespace, a 1-2 digit number (captured), a period, then one whitespace char.
ITEM_SPLIT_RE = re.compile(r"(?m)^\s*(\d{1,2})\.\s")

# Condition names; each is used as the filename prefix and analyzed in turn.
CONDITIONS = ["BaselineNoJSON", "BaselineWithJSON", "HighRigor", "PushbackStrong"]

def trim_item50_noise(text: str) -> str:
    """Drop a trailing "please note ..." paragraph from *text*.

    The cut happens at the first blank-line-separated paragraph that begins
    with the words "please note" (case-insensitive). Trailing whitespace is
    stripped from the result whether or not such a paragraph was found.
    """
    marker = re.search(r"(?ims)\r?\n\r?\n+please note\b", text)
    # Without a marker, keep the whole text and only strip its tail.
    cut = len(text) if marker is None else marker.start()
    return text[:cut].rstrip()

def extract_items(text: str) -> dict:
    """Split a numbered list into a mapping of item number to item text.

    Item headers are located with ``ITEM_SPLIT_RE``. Each item's text is
    stripped of surrounding whitespace, and item 50 additionally has any
    trailing "please note" paragraph removed via ``trim_item50_noise``.

    Args:
        text: Full text expected to contain items numbered 1 through 50.

    Returns:
        A dict with exactly the keys 1..50 mapping to each item's text.

    Raises:
        ValueError: If the set of item numbers found is not exactly 1..50.
            The message lists the missing and the unexpected numbers.
    """
    # With one capture group, split() yields
    # [preamble, number, body, number, body, ...].
    splits = ITEM_SPLIT_RE.split(text)
    items = {}
    for i in range(1, len(splits), 2):
        # A repeated item number overwrites the earlier entry.
        items[int(splits[i])] = splits[i + 1].strip()
    if 50 in items:
        items[50] = trim_item50_noise(items[50])

    expected = set(range(1, 51))
    found = set(items)
    if found != expected:
        # Report exactly which numbers are wrong so a malformed file can be
        # diagnosed without re-parsing it by hand.
        missing = sorted(expected - found)
        unexpected = sorted(found - expected)
        raise ValueError(
            f"Bad item keys: missing={missing}, unexpected={unexpected}"
        )
    return items

def mean_wordcounts(paths):
    """Return the per-item mean word count across the files in *paths*.

    Each file is read as UTF-8 (undecodable bytes are ignored), parsed with
    ``extract_items``, and every item's whitespace-separated word count is
    recorded. The result maps item numbers 1..50 to the mean count as a float.
    """
    item_ids = range(1, 51)
    per_run_counts = []
    for path in paths:
        raw = Path(path).read_text(encoding="utf-8", errors="ignore")
        parsed = extract_items(raw)
        per_run_counts.append({n: len(parsed[n].split()) for n in item_ids})
    return {
        n: float(np.mean([counts[n] for counts in per_run_counts]))
        for n in item_ids
    }

def build_filename(condition, hardware, tag):
    """Return the log filename ``<condition>_<hardware>_T0_<tag>.txt``."""
    return "{}_{}_T0_{}.txt".format(condition, hardware, tag)

def files_for_group(base, cond, hw, group, warm_tags, alt_tags, cold_tag):
    """Return the log file paths under *base* for one condition/hardware/group.

    "Warm" selects *warm_tags*, "Alt" selects *alt_tags*, and any other group
    value selects the single *cold_tag*. One path is built per selected tag.
    """
    selected_tags = (
        warm_tags if group == "Warm"
        else alt_tags if group == "Alt"
        else [cold_tag]
    )
    paths = []
    for tag in selected_tags:
        paths.append(base / build_filename(cond, hw, tag))
    return paths

def analyze(cond, base, group, warm_tags, alt_tags, cold_tag):
    """Compare CPU-vs-GPU divergence of items 1-40 against items 41-50.

    For condition *cond* and run group *group*, the per-item mean word counts
    of the CPU and GPU logs are computed and their absolute difference is
    taken. The first 40 differences are tested against the last 10 with
    Welch's t-test and a two-sided Mann-Whitney U test.

    Returns:
        A dict with the condition, the group, the mean divergence of each
        subset, and the statistic and p-value of both tests.

    Raises:
        FileNotFoundError: If any expected log file does not exist; the
            message names the first missing path (CPU files checked first).
    """
    cpu_paths, gpu_paths = [
        files_for_group(base, cond, hw, group, warm_tags, alt_tags, cold_tag)
        for hw in ("CPU", "GPU")
    ]

    # Verify every input up front so nothing is parsed on a partial file set.
    first_missing = next(
        (path for path in cpu_paths + gpu_paths if not path.exists()), None
    )
    if first_missing is not None:
        raise FileNotFoundError(f"Missing: {first_missing}")

    cpu_by_item = mean_wordcounts(cpu_paths)
    gpu_by_item = mean_wordcounts(gpu_paths)

    item_ids = range(1, 51)
    divergence = np.abs(
        np.array([cpu_by_item[n] for n in item_ids])
        - np.array([gpu_by_item[n] for n in item_ids])
    )

    # Index 0 holds item 1, so the first 40 entries are items 1-40.
    early, late = divergence[:40], divergence[40:]

    welch_stat, welch_pvalue = stats.ttest_ind(early, late, equal_var=False)
    mwu_stat, mwu_pvalue = stats.mannwhitneyu(early, late, alternative="two-sided")

    summary = {"condition": cond, "group": group}
    summary["mean_div_1_40"] = float(np.mean(early))
    summary["mean_div_41_50"] = float(np.mean(late))
    summary["welch_t"] = float(welch_stat)
    summary["welch_p"] = float(welch_pvalue)
    summary["mannwhitney_u"] = float(mwu_stat)
    summary["mannwhitney_p"] = float(mwu_pvalue)
    return summary

def main():
    """Parse command-line options and print the report for every condition."""
    option_specs = (
        ("--base_dir", {"required": True}),
        ("--group", {"choices": ["Warm","Cold","Alt"], "default": "Warm"}),
        ("--warm_tags", {"nargs": "+", "default": ["W2","W3","W4"]}),
        ("--alt_tags", {"nargs": "+", "default": ["A1","A2","A3"]}),
        ("--cold_tag", {"default": "C1"}),
    )
    parser = argparse.ArgumentParser()
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)
    opts = parser.parse_args()

    log_dir = Path(opts.base_dir)
    for condition in CONDITIONS:
        result = analyze(
            condition,
            log_dir,
            opts.group,
            opts.warm_tags,
            opts.alt_tags,
            opts.cold_tag,
        )
        report_lines = (
            f"\n=== {result['condition']} ({result['group']}) ===",
            f"Mean divergence (1–40):  {result['mean_div_1_40']:.3f}",
            f"Mean divergence (41–50): {result['mean_div_41_50']:.3f}",
            f"Welch t-test: t={result['welch_t']:.3f}, p={result['welch_p']:.6g}",
            f"Mann–Whitney U: U={result['mannwhitney_u']:.3f}, p={result['mannwhitney_p']:.6g}",
        )
        for line in report_lines:
            print(line)

# Run the command-line interface only when executed as a script, not on import.
if __name__ == "__main__":
    main()
